{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b076bd1a-b236-4fbc-953d-8295b25122ae",
   "metadata": {},
   "source": [
    "# 🥙 LSTM on Recipe Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "658a95da-9645-4bcf-bd9d-4b95a4b6f582",
   "metadata": {},
   "source": [
    "In this notebook, we'll walk through the steps required to train your own LSTM on the recipes dataset and then use it to generate new recipe text, one word at a time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e0d56cc-4773-4029-97d8-26f882ba79c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "import json\n",
    "import re\n",
    "import string\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers, models, callbacks, losses"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5",
   "metadata": {},
   "source": [
    "## 0. Parameters <a name=\"parameters\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d8352af-343e-4c2e-8c91-95f8bac1c8a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "VOCAB_SIZE = 10000  # maximum vocabulary size passed to TextVectorization\n",
    "MAX_LEN = 200  # tokens per training input (targets are the same text shifted by one)\n",
    "EMBEDDING_DIM = 100\n",
    "N_UNITS = 128  # LSTM hidden units\n",
    "VALIDATION_SPLIT = 0.2  # NOTE(review): currently unused - no validation split is made below\n",
    "SEED = 42\n",
    "LOAD_MODEL = False  # set True to start from the model saved at ./models/lstm\n",
    "BATCH_SIZE = 32\n",
    "EPOCHS = 25\n",
    "\n",
    "# Seed every random source used below: dataset shuffling and weight\n",
    "# initialisation (TensorFlow) and next-token sampling (NumPy)\n",
    "np.random.seed(SEED)\n",
    "tf.random.set_seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7716fac-0010-49b0-b98e-53be2259edde",
   "metadata": {},
   "source": [
    "## 1. Load the data <a name=\"load\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93cf6b0f-9667-4146-8911-763a8a2925d3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the full dataset. Read explicitly as UTF-8 so any non-ASCII text decodes\n",
    "# the same way regardless of the platform's default encoding.\n",
    "DATA_PATH = \"/app/data/epirecipes/full_format_recipes.json\"\n",
    "with open(DATA_PATH, encoding=\"utf-8\") as json_data:\n",
    "    recipe_data = json.load(json_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23a74eca-f1b7-4a46-9a1f-b5806a4ed361",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Keep only recipes that have both a title and directions, flattening each one\n",
    "# into a single \"Recipe for <title> | <directions>\" string\n",
    "def _has_field(recipe, key):\n",
    "    return key in recipe and recipe[key] is not None\n",
    "\n",
    "\n",
    "filtered_data = [\n",
    "    \"Recipe for \" + recipe[\"title\"] + \" | \" + \" \".join(recipe[\"directions\"])\n",
    "    for recipe in recipe_data\n",
    "    if _has_field(recipe, \"title\") and _has_field(recipe, \"directions\")\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "389c20de-0422-4c48-a7b4-6ee12a7bf0e2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Report how many recipes survived the filter\n",
    "n_recipes = len(filtered_data)\n",
    "print(n_recipes, \"recipes loaded\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b2e3cf7-e416-460e-874a-0dd9637bca36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at one raw (not yet tokenised) recipe string\n",
    "example = filtered_data[9]\n",
    "print(example)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f871aaf-d873-41c7-8946-e4eef7ac17c1",
   "metadata": {},
   "source": [
    "## 2. Tokenise the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b2064fb-5dcc-4657-b470-0928d10e2ddc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Pad the punctuation, to treat them as separate 'words'.\n",
    "# re.escape is required: spliced raw into a character class, string.punctuation's\n",
    "# \"\\]\" would be read as an escaped \"]\" and the backslash would never be padded.\n",
    "_PUNCTUATION_RE = re.compile(f\"([{re.escape(string.punctuation)}])\")\n",
    "_MULTI_SPACE_RE = re.compile(\" +\")\n",
    "\n",
    "\n",
    "def pad_punctuation(s):\n",
    "    \"\"\"Return `s` with every character of string.punctuation surrounded by spaces\n",
    "    and runs of spaces collapsed to one, e.g. \"a,b\" -> \"a , b\".\"\"\"\n",
    "    s = _PUNCTUATION_RE.sub(r\" \\1 \", s)\n",
    "    return _MULTI_SPACE_RE.sub(\" \", s)\n",
    "\n",
    "\n",
    "text_data = [pad_punctuation(x) for x in filtered_data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b87d7c65-9a46-492a-a5c0-a043b0d252f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the same recipe (index 9) after punctuation padding\n",
    "example_data = text_data[9]\n",
    "example_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9834f916-b21a-4104-acc9-f28d3bd7a8c1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Convert to a TensorFlow Dataset. Shuffle individual recipes *before* batching\n",
    "# (a full-size buffer gives a uniform shuffle, redone every epoch), so the recipes\n",
    "# grouped into each batch change between epochs rather than only the batch order.\n",
    "text_ds = (\n",
    "    tf.data.Dataset.from_tensor_slices(text_data)\n",
    "    .shuffle(buffer_size=len(text_data), seed=SEED)\n",
    "    .batch(BATCH_SIZE)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "884c0bcb-0807-45a1-8f7e-a32f2c6fa4de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a vectorisation layer: lower-case only (punctuation was already split off\n",
    "# by pad_punctuation), cap the vocabulary at VOCAB_SIZE entries, and pad/truncate\n",
    "# to MAX_LEN + 1 ids so inputs and next-token targets can each be MAX_LEN long\n",
    "vectorize_layer = layers.TextVectorization(\n",
    "    max_tokens=VOCAB_SIZE,\n",
    "    output_sequence_length=MAX_LEN + 1,\n",
    "    output_mode=\"int\",\n",
    "    standardize=\"lower\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d6dd34a-d905-497b-926a-405380ebcf98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the vocabulary from the batched recipe strings, then read it back\n",
    "vectorize_layer.adapt(text_ds)\n",
    "vocab = vectorize_layer.get_vocabulary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6c1c7ce-3cf0-40d4-a3dc-ab7090f69f2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the first few token-id -> word mappings\n",
    "for token_id, token in enumerate(vocab[:10]):\n",
    "    print(f\"{token_id}: {token}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cc30186-7ec6-4eb6-b29a-65df6714d321",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the same example converted to integer token ids\n",
    "example_tokenised = vectorize_layer(example_data)\n",
    "example_ids = example_tokenised.numpy()\n",
    "print(example_ids)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c195efb-84c6-4be0-a989-a7542188ad35",
   "metadata": {},
   "source": [
    "## 3. Create the Training Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "740294a1-1a6b-4c89-92f2-036d7d1b788b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the training set: each input is a recipe's token ids and each target is\n",
    "# the same sequence shifted one position to the left (next-token prediction)\n",
    "def prepare_inputs(text):\n",
    "    \"\"\"Map a batch of strings to (inputs, targets) pairs of token ids.\"\"\"\n",
    "    token_ids = vectorize_layer(tf.expand_dims(text, -1))\n",
    "    return token_ids[:, :-1], token_ids[:, 1:]\n",
    "\n",
    "\n",
    "train_ds = text_ds.map(prepare_inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aff50401-3abe-4c10-bba8-b35bc13ad7d5",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 4. Build the LSTM <a name=\"build\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9230b5bf-b4a8-48d5-b73b-6899a598f296",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embedding -> LSTM -> per-position softmax over the vocabulary\n",
    "inputs = layers.Input(shape=(None,), dtype=\"int32\")\n",
    "x = layers.Embedding(input_dim=VOCAB_SIZE, output_dim=EMBEDDING_DIM)(inputs)\n",
    "x = layers.LSTM(units=N_UNITS, return_sequences=True)(x)\n",
    "outputs = layers.Dense(units=VOCAB_SIZE, activation=\"softmax\")(x)\n",
    "lstm = models.Model(inputs=inputs, outputs=outputs)\n",
    "lstm.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "800a3c6e-fb11-4792-b6bc-9a43a7c977ad",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Optionally start from the model saved by the \"Save the final model\" cell below.\n",
    "# compile=False because the next section compiles the model before training.\n",
    "if LOAD_MODEL:\n",
    "    lstm = models.load_model(\"./models/lstm\", compile=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35b14665-4359-447b-be58-3fd58ba69084",
   "metadata": {},
   "source": [
    "## 5. Train the LSTM <a name=\"train\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffb1bd3b-6fd9-4536-973e-6375bbcbf16d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Targets are integer token ids and the model already outputs softmax probabilities\n",
    "loss_fn = losses.SparseCategoricalCrossentropy()\n",
    "lstm.compile(optimizer=\"adam\", loss=loss_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ddcff5f-829d-4449-99d2-9a3cb68f7d72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a TextGenerator callback that samples text from the model after each epoch\n",
    "class TextGenerator(callbacks.Callback):\n",
    "    \"\"\"Keras callback that generates text with the model being trained.\n",
    "\n",
    "    Args:\n",
    "        index_to_word: vocabulary list where position i holds the word for\n",
    "            token id i (here, the output of `get_vocabulary()`).\n",
    "        top_k: kept as `self.top_k`; NOTE(review): not used by the sampling below.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, index_to_word, top_k=10):\n",
    "        super().__init__()  # let Keras set up its own callback state\n",
    "        self.index_to_word = index_to_word\n",
    "        self.top_k = top_k\n",
    "        # Reverse lookup used to turn prompt words into token ids\n",
    "        self.word_to_index = {\n",
    "            word: index for index, word in enumerate(index_to_word)\n",
    "        }\n",
    "\n",
    "    def sample_from(self, probs, temperature):\n",
    "        \"\"\"Sample a token id from `probs` after temperature scaling.\n",
    "\n",
    "        Returns (sampled id, rescaled and renormalised distribution). Temperatures\n",
    "        below 1 sharpen the distribution; above 1 flatten it.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if `temperature` is not positive.\n",
    "        \"\"\"\n",
    "        if temperature <= 0:\n",
    "            raise ValueError(f\"temperature must be positive, got {temperature}\")\n",
    "        probs = np.asarray(probs, dtype=np.float64)\n",
    "        # Dividing by the max first leaves the renormalised result unchanged, but the\n",
    "        # top entry becomes 1.0, so low temperatures can no longer underflow every\n",
    "        # entry to 0 (a zero sum would turn the distribution into NaNs).\n",
    "        probs = (probs / probs.max()) ** (1 / temperature)\n",
    "        probs = probs / np.sum(probs)\n",
    "        return np.random.choice(len(probs), p=probs), probs\n",
    "\n",
    "    def generate(self, start_prompt, max_tokens, temperature):\n",
    "        \"\"\"Extend `start_prompt` one sampled token at a time and print the result.\n",
    "\n",
    "        Stops once the sequence holds `max_tokens` tokens (prompt included) or\n",
    "        token id 0 is sampled. Prompt words missing from the vocabulary map to id 1.\n",
    "\n",
    "        Returns:\n",
    "            One dict per generated token with the text so far (\"prompt\") and the\n",
    "            distribution the token was sampled from (\"word_probs\").\n",
    "        \"\"\"\n",
    "        # The vectoriser lower-cases its input, so look prompt words up in lower case\n",
    "        start_tokens = [\n",
    "            self.word_to_index.get(word.lower(), 1) for word in start_prompt.split()\n",
    "        ]\n",
    "        sample_token = None\n",
    "        info = []\n",
    "        while len(start_tokens) < max_tokens and sample_token != 0:\n",
    "            x = np.array([start_tokens])\n",
    "            y = self.model.predict(x, verbose=0)\n",
    "            # Sample from the prediction at the last position, i.e. the next token\n",
    "            sample_token, probs = self.sample_from(y[0][-1], temperature)\n",
    "            info.append({\"prompt\": start_prompt, \"word_probs\": probs})\n",
    "            start_tokens.append(sample_token)\n",
    "            start_prompt = start_prompt + \" \" + self.index_to_word[sample_token]\n",
    "        print(f\"\\ngenerated text:\\n{start_prompt}\\n\")\n",
    "        return info\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        self.generate(\"recipe for\", max_tokens=100, temperature=1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "349865fe-ffbe-450e-97be-043ae1740e78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the model weights at the end of every epoch\n",
    "model_checkpoint_callback = callbacks.ModelCheckpoint(\n",
    "    filepath=\"./checkpoint/checkpoint.ckpt\",\n",
    "    save_freq=\"epoch\",\n",
    "    save_weights_only=True,\n",
    "    verbose=0,\n",
    ")\n",
    "\n",
    "# Log training metrics for TensorBoard\n",
    "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n",
    "\n",
    "# Print a sampled recipe after each epoch to track generation quality\n",
    "text_generator = TextGenerator(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "461c2b3e-b5ae-4def-8bd9-e7bab8c63d8e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Train; the callbacks checkpoint weights, log to TensorBoard and sample text each epoch\n",
    "training_callbacks = [model_checkpoint_callback, tensorboard_callback, text_generator]\n",
    "lstm.fit(train_ds, epochs=EPOCHS, callbacks=training_callbacks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "369bde44-2e39-4bc6-8549-a3a27ecce55c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Save the final model; the LOAD_MODEL cell in section 4 reloads it from this path\n",
    "lstm.save(\"./models/lstm\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d64e02d2-84dc-40c8-8446-40c09adf1e20",
   "metadata": {},
   "source": [
    "## 6. Generate text using the LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ad23adb-3ec9-4e9a-9a59-b9f9bafca649",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_probs(info, vocab, top_k=5):\n",
    "    \"\"\"Print the `top_k` most likely next words for every generation step.\n",
    "\n",
    "    Args:\n",
    "        info: list of dicts as returned by `TextGenerator.generate`, each with a\n",
    "            \"prompt\" string and a \"word_probs\" array indexed by token id.\n",
    "        vocab: list mapping token id -> word.\n",
    "        top_k: number of candidates to show per step.\n",
    "    \"\"\"\n",
    "    for step in info:\n",
    "        print(f\"\\nPROMPT: {step['prompt']}\")\n",
    "        word_probs = step[\"word_probs\"]\n",
    "        # Token ids ordered from most to least likely, truncated to top_k\n",
    "        top_ids = np.argsort(word_probs)[::-1][:top_k]\n",
    "        for token_id in top_ids:\n",
    "            print(f\"{vocab[token_id]}:   \\t{np.round(100 * word_probs[token_id], 2)}%\")\n",
    "        print(\"--------\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cf25578-d47c-4b26-8252-fcdf2316a4ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Temperature 1.0 samples directly from the model's predicted distribution\n",
    "info = text_generator.generate(\n",
    "    start_prompt=\"recipe for roasted vegetables | chop 1 /\",\n",
    "    max_tokens=10,\n",
    "    temperature=1.0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9df72866-b483-4489-8e26-d5e1466410fa",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Show the most likely candidates at each generation step\n",
    "print_probs(info, vocab, top_k=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "562e1fe8-cbcb-438f-9637-2f2a6279c924",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A low temperature sharpens the distribution towards the most likely words\n",
    "info = text_generator.generate(\n",
    "    start_prompt=\"recipe for roasted vegetables | chop 1 /\",\n",
    "    max_tokens=10,\n",
    "    temperature=0.2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56356f21-04ac-40e5-94ff-291eca6a7054",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare with the temperature 1.0 candidates above\n",
    "print_probs(info, vocab, top_k=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e434497-07f3-4989-a68d-3e31cf8fa4fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start the directions for a new recipe title at temperature 1.0\n",
    "info = text_generator.generate(\n",
    "    start_prompt=\"recipe for chocolate ice cream |\",\n",
    "    max_tokens=7,\n",
    "    temperature=1.0,\n",
    ")\n",
    "print_probs(info, vocab, top_k=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "011cd0e0-956c-4a63-8ec3-f7dfed31764e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same prompt at a low temperature, which concentrates probability on the top words\n",
    "info = text_generator.generate(\n",
    "    start_prompt=\"recipe for chocolate ice cream |\",\n",
    "    max_tokens=7,\n",
    "    temperature=0.2,\n",
    ")\n",
    "print_probs(info, vocab, top_k=5)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
